{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/bminixhofer/nnsplit/blob/master/train/train.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "5HlpOyhRdgm4"
   },
   "source": [
    "This notebook shows how to train [NNSplit](https://github.com/bminixhofer/nnsplit/) on a custom dataset and load it for inference."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "PF8h9mSzRnCJ"
   },
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "XYvCFh9cdrZ3"
   },
   "source": [
    "First, clone the GitHub repository and install the requirements. If you are running this on Colab, you will likely have to restart the runtime after installing the requirements because of version mismatches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 121
    },
    "colab_type": "code",
    "id": "smd7hIM_zcLO",
    "outputId": "1d1b9e33-2493-4b59-c491-42fc0da5fd07"
   },
   "outputs": [],
   "source": [
    "!git clone https://www.github.com/bminixhofer/nnsplit"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "8vFPPwbuQ9v0"
   },
   "source": [
    "# Data preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "N74zUHFLd2aK"
   },
   "source": [
    "Training NNSplit is not limited to a specific dataset. However, I have found the [Linguatools Wikipedia Dumps](https://linguatools.org/tools/corpora/wikipedia-monolingual-corpora/) to work well, so there is built-in functionality to load those. Feel free to use other data!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "VGEolHUkeIHK"
   },
   "source": [
    "First, download the `.xml.bz2` file and decompress it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 384
    },
    "colab_type": "code",
    "id": "uabNv05OneLK",
    "outputId": "d2d0b4d3-1601-4a2e-81dc-40e9048fc536"
   },
   "outputs": [],
   "source": [
    "!wget https://www.dropbox.com/s/cnrhd11zdtc1pic/enwiki-20181001-corpus.xml.bz2?dl=1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "L4aHHmAEoj25"
   },
   "outputs": [],
   "source": [
    "!mv enwiki-20181001-corpus.xml.bz2?dl=1 enwiki-20181001-corpus.xml.bz2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "1GY4mHCboaUu"
   },
   "outputs": [],
   "source": [
    "!bzip2 -d enwiki-20181001-corpus.xml.bz2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UArMGRYReVGT"
   },
   "source": [
    "Now we can create the dataset. `xml_dump_iter` is one of the built-in functions. It returns an iterator over all texts in the Wikipedia dump and tries to remove tags and other markup."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "74lGj78HszTo"
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"nnsplit/train\")\n",
    "from text_data import MemoryMapDataset, xml_dump_iter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 123
    },
    "colab_type": "code",
    "id": "tQ4Cedjws7e6",
    "outputId": "59052077-0e73-4756-d4d8-8854d114f9f3"
   },
   "outputs": [],
   "source": [
    "xml_iter = xml_dump_iter(\"enwiki-20181001-corpus.xml\", \n",
    "                         min_text_length=300, \n",
    "                         max_text_length=5000)\n",
    "next(xml_iter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "KMazX62Pe66B"
   },
   "source": [
    "`MemoryMapDataset` is another convenient built-in class, but it is not specific to the Wikipedia dump. It is a `torch.utils.data.Dataset` which can be created from a `texts.txt` and a `slices.pkl` file. The `texts.txt` file is [memory-mapped](https://en.wikipedia.org/wiki/Memory-mapped_file), and `slices.pkl` contains a Python array of indices that determines which range of the text is loaded at each position in the dataset. This allows accessing each text without ever loading all the data into memory.\n",
    "\n",
    "To create `texts.txt` and `slices.pkl` from an iterator over text, use `MemoryMapDataset.iterator_to_text_and_slices`.\n",
    "\n",
    "Note that this will be quite slow since iterating over the XML dump takes a significant amount of time, so I would recommend caching `texts.txt` and `slices.pkl` somewhere.\n",
    "\n",
    "`max_n_texts=10_000_000` is only needed in Colab to keep disk usage in check; feel free to remove it otherwise."
   ]
  },
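  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea concrete, here is a minimal standard-library sketch of a memory-mapped text file with a pickled list of byte ranges. It is only an illustration of the technique, not the actual implementation of `MemoryMapDataset`; the temporary file paths, the example texts and the `(start, end)` tuple layout are assumptions made for this sketch.\n",
    "\n",
    "```python\n",
    "import mmap\n",
    "import os\n",
    "import pickle\n",
    "import tempfile\n",
    "\n",
    "# Hypothetical texts and file locations, for illustration only.\n",
    "texts = ['First text.', 'Second, longer text.', 'Third.']\n",
    "tmp_dir = tempfile.mkdtemp()\n",
    "text_path = os.path.join(tmp_dir, 'texts.txt')\n",
    "slice_path = os.path.join(tmp_dir, 'slices.pkl')\n",
    "\n",
    "# Write all texts back to back and record the byte range of each one.\n",
    "slices = []\n",
    "offset = 0\n",
    "with open(text_path, 'wb') as f:\n",
    "    for text in texts:\n",
    "        data = text.encode('utf-8')\n",
    "        f.write(data)\n",
    "        slices.append((offset, offset + len(data)))\n",
    "        offset += len(data)\n",
    "with open(slice_path, 'wb') as f:\n",
    "    pickle.dump(slices, f)\n",
    "\n",
    "# Read one entry: memory-map the file and decode only the requested range.\n",
    "with open(slice_path, 'rb') as f:\n",
    "    loaded_slices = pickle.load(f)\n",
    "with open(text_path, 'rb') as f:\n",
    "    mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)\n",
    "    start, end = loaded_slices[1]\n",
    "    second = mm[start:end].decode('utf-8')\n",
    "    mm.close()\n",
    "print(second)\n",
    "```\n",
    "\n",
    "Only the small list of ranges has to be held in memory. The operating system pages in the parts of the text file that are actually read."
   ]
  },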
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 66,
     "referenced_widgets": [
      "22f62fda33fe4e8f8a65030e50f035d4",
      "9724f0a36fbd47679f549d1ea38b8f87",
      "c800829bcf0043cb906175554c99c7a8",
      "5aec69645e6548c8a651b5430e369d25",
      "b68af80929bd48e6bf09230d26b9fc2c",
      "6ff19989f6684a41a4cf55266c6e25e3",
      "df7e91b8c2ed4f9baa27f582202cff5c",
      "99c513e7d9cb4c2885613142fdabd6c9"
     ]
    },
    "colab_type": "code",
    "id": "zlRunloks4RU",
    "outputId": "6d4f8f2a-377e-495c-9783-4c6d01bab5e9"
   },
   "outputs": [],
   "source": [
    "xml_iter = xml_dump_iter(\"enwiki-20181001-corpus.xml\", \n",
    "                         min_text_length=300,\n",
    "                         max_text_length=5000)\n",
    "MemoryMapDataset.iterator_to_text_and_slices(xml_iter, \n",
    "                                             \"texts.txt\", \n",
    "                                             \"slices.pkl\",\n",
    "                                             max_n_texts=10_000_000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "y1BvQ894gvKJ"
   },
   "source": [
    "Here, I am saving the outputs to my Drive; you will have to adjust these paths."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "xQs7qQHHtXVH"
   },
   "outputs": [],
   "source": [
    "!cp -a slices.pkl \"drive/My Drive/Projects/nnsplit/slices.pkl\"\n",
    "!cp -a texts.txt \"drive/My Drive/Projects/nnsplit/texts.txt\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "AZpLGc_kPlOq"
   },
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Nh9UKxN4hFuw"
   },
   "source": [
    "Now we can get started with training!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Ju0AhUATnvRx"
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"nnsplit/train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "X96MCCfQRBDG"
   },
   "outputs": [],
   "source": [
    "import json\n",
    "from pytorch_lightning.trainer import Trainer\n",
    "from tqdm.auto import tqdm\n",
    "from model import Network\n",
    "from text_data import MemoryMapDataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ogqWNcZkhKEB"
   },
   "source": [
    "NNSplit has a `Network` class which is a `pl.LightningModule` specifying network architecture, data loading logic etc. To instantiate a new network, we need to first get the default hyperparameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 54
    },
    "colab_type": "code",
    "id": "mKNM1sVpR-Lq",
    "outputId": "75c52494-2ea4-48ad-cf48-f369eaedfe9a"
   },
   "outputs": [],
   "source": [
    "parser = Network.get_parser()\n",
    "hparams = parser.parse_args([])\n",
    "hparams"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "vMldDgFkS3ZU"
   },
   "source": [
    "## Load text data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "QE2YMQOrhy24"
   },
   "source": [
    "Next, we can load the text data created previously."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ADW2U60kSnOf"
   },
   "outputs": [],
   "source": [
    "text_dataset = MemoryMapDataset(\"texts.txt\", \"slices.pkl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "8YOPVi2vh2dG"
   },
   "source": [
    "Keep in mind that this can be any `torch.utils.data.Dataset` with `str` entries, so you can completely customize it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 123
    },
    "colab_type": "code",
    "id": "hS-KFyiOS8Bj",
    "outputId": "e9b8794c-4931-4877-a91a-37db1160ee36"
   },
   "outputs": [],
   "source": [
    "text_dataset[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "GEDFRxU_S-4w"
   },
   "source": [
    "## Load labeler"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "I5cVdSTFiE9o"
   },
   "source": [
    "Next, create a `Labeler`, which is used to annotate the text from above. Any spaCy model which supports sentencization can be used. You will have to install the appropriate spaCy model with `python -m spacy ...` when running this in Colab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "tkwjwBPiCUPl"
   },
   "outputs": [],
   "source": [
    "from labeler import Labeler, SpacySentenceTokenizer, SpacyWordTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "q568uRViSHco"
   },
   "outputs": [],
   "source": [
    "labeler = Labeler(\n",
    "    [\n",
    "        SpacySentenceTokenizer(\n",
    "            \"en_core_web_sm\", lower_start_prob=0.7, remove_end_punct_prob=0.7\n",
    "        ),\n",
    "        SpacyWordTokenizer(\"en_core_web_sm\"),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "aps02am6iio_"
   },
   "source": [
    "`Labeler.visualize` shows you what the network sees:\n",
    "- `byte` is the UTF-8 encoded text. This has changed in the newest version of NNSplit. Previously, characters were used, but using bytes allows NNSplit to work for any language regardless of the characters used to represent it.\n",
    "- The other rows depend on the `Labeler` and determine what the neural network tries to predict."
   ]
  },
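  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small illustration of why bytes are language-independent, the sketch below encodes a few strings with Python's built-in `str.encode`. The example strings are arbitrary, and this only shows the UTF-8 view of the text, not how NNSplit preprocesses its input internally.\n",
    "\n",
    "```python\n",
    "# Arbitrary example strings in different scripts.\n",
    "for text in ['test', 'Größe', '日本語']:\n",
    "    data = text.encode('utf-8')\n",
    "    print(text, len(text), 'characters ->', len(data), 'bytes:', list(data))\n",
    "\n",
    "# Whatever the script, every value is an integer between 0 and 255.\n",
    "all_in_range = all(0 <= b < 256 for b in '日本語'.encode('utf-8'))\n",
    "print(all_in_range)\n",
    "```\n",
    "\n",
    "A byte-level input therefore needs at most 256 distinct values, while a character-level input would need an entry for every character that can occur."
   ]
  },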
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 347
    },
    "colab_type": "code",
    "id": "IL5-_awsSfYX",
    "outputId": "617b0849-268d-42f4-85be-ed81612c4fef"
   },
   "outputs": [],
   "source": [
    "labeler.visualize(\"This is a test. This is another test.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "CL7gZKuITDYA"
   },
   "source": [
    "## Start training!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "6_cI7PJ0jNfH"
   },
   "source": [
    "Now we can finally start training.\n",
    "\n",
    "`train_size` determines how many entries in the dataset to sample for each epoch.\n",
    "\n",
    "Using spaCy with multiprocessing leaks memory, so memory usage will continuously increase during each epoch and reset at the end. You will therefore have to set `train_size` to a value that fits the available memory. `500_000` works well in Colab.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "3D14-MKkTUL3"
   },
   "outputs": [],
   "source": [
    "hparams.gpus = 1\n",
    "hparams.max_epochs = 4\n",
    "hparams.train_size = 500_000\n",
    "hparams.predict_indices = [0, 1] # which split levels of the labeler to predict\n",
    "# how to weigh the selected indices\n",
    "# in general sentence boundary detection should be weighed the highest\n",
    "hparams.level_weights = [0.1, 2.0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "fq2u1s6Cj5sT"
   },
   "source": [
    "Instantiate the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 121
    },
    "colab_type": "code",
    "id": "GV385cJtTNVN",
    "outputId": "2d378e43-8bf8-49d7-9603-0a2654bd3b26"
   },
   "outputs": [],
   "source": [
    "model = Network(\n",
    "  text_dataset,\n",
    "  labeler,\n",
    "  hparams,\n",
    ")\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "OHDW5wbKj8xz"
   },
   "source": [
    "Instantiate the `pl.trainer.Trainer`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 69
    },
    "colab_type": "code",
    "id": "TzV01atSTjKe",
    "outputId": "78241dbd-cb4b-4615-830d-9c628bf6e9cb"
   },
   "outputs": [],
   "source": [
    "trainer = Trainer.from_argparse_args(hparams)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "o6_n-DuNkD0o"
   },
   "source": [
    "And fit the model. Each row of the F1 and precision scores corresponds to one tokenizer of the `Labeler`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 466,
     "referenced_widgets": [
      "3ae606b9b3334fa9b67edfe204d5defa",
      "f40bf8ecf8264ea99fa321c552898862",
      "b49964f347e843369526eebc2847900d",
      "8e2ae9dfe8b64c0099bcafcae88108dc",
      "93adbfe460f64ca997d7765ed253656c",
      "bbf0d1648eae43c0a46a5d933bb4850f",
      "bef7ef4752f1416b97c3760745d3b436",
      "5958137bdce541619a8e33a0a228fc8e",
      "eea4bca6822146ad90a26e2a20349e6e",
      "5a9f9f2fe8004eeda014ec65c247eb6c",
      "af587a9dfe614307b68262c896c0da94",
      "945888de9e3247c4ad3fb6e9c27eeef7",
      "5b8b180011f040d9b249bdddae331faf",
      "a9727c57bc5948e8ba92ab661e4c3f39",
      "b378260d471b4de7b2a2da665de3154c",
      "bd6d9cf15499465ea5a0973ddcd84c3e",
      "4e46dec9c4564566b0f273039b8d5f6e",
      "088d96702c22436f84e7f30086d1b764",
      "98f90816463744ae9ab14423623d8c4f",
      "4b678eb4cbeb47d6b7b2d0b8d0876ef0",
      "b492402c416340418e04faabc989eebd",
      "e29e8c05d39a48a292fb89a9bf844763",
      "05c8249d9dd847599fbbd0c8e91b9b4f",
      "b2d7d208d0694ce1a8f6f2ceb3eb1337",
      "cbf674041db34631b89c7b0efd06de40",
      "233e4c45a82d4fa699171e380321acf2",
      "7a3227d6d43b43eea96fda35400f378f",
      "2e2434b01e124a2a93fc3bf880403e2d",
      "5c36a80489b6486aa36a37ea4b0d5db8",
      "0638213131af43669c8fd45980a4e8f7",
      "4bec521f9ea749c48ba668010dbca59f",
      "940cc32ee98440a998cdf03ccde84ca9",
      "ab0aed4a7b3447749556f8ff8e29985b",
      "238ba8d9c6864e30b21bce77d2d5ae24",
      "cae043072d9646689e41606bd9e43ca4",
      "038f17e97e504eccbe57f7a7c5ac3a9e",
      "54a77d69b07246ebb5b7ca9922ff8d50",
      "cbd21beea6d741d99d2a9a6380719206",
      "330bb02f92df4e89a8a8085a5b91b24e",
      "b479db59506046c187e10e17d49bae61",
      "7c1bf6e624f844bab69e1eccc558497f",
      "80d73557b2fb4f52837d785555d71d93",
      "b6d4586e8cf14e75af81a3ec35214c65",
      "cee4f64f1b3c46f5a1c53f361381f46d",
      "66f73f9a6edd4c32924c16241b511f77",
      "ccff26f6082445c49fdbca68e1225ea6",
      "ac791bee8f164ff6a271e178d614727f",
      "a2c6f6dafd6f49e6bdd06811a449fd38"
     ]
    },
    "colab_type": "code",
    "id": "Awa3ZO5oTl7X",
    "outputId": "4443e683-0d95-4637-f37c-44b64a1fbd43"
   },
   "outputs": [],
   "source": [
    "trainer.fit(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Iq16P_pqklch"
   },
   "source": [
    "Finally, store the trained model somewhere. This saves an `.onnx` export of the model in the specified directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 72
    },
    "colab_type": "code",
    "id": "s5tSj4DGoYfL",
    "outputId": "6e2cdafd-ce08-4b64-b2bf-a583c43124de"
   },
   "outputs": [],
   "source": [
    "# onnx metadata which determines how to use the prediction indices to split text\n",
    "metadata = {\n",
    "    \"split_sequence\": json.dumps(\n",
    "        {\n",
    "            \"instructions\": [\n",
    "                [\"Sentence\", {\"PredictionIndex\": 0}],\n",
    "                [\"Token\", {\"PredictionIndex\": 1}],\n",
    "                [\"_Whitespace\", {\"Function\": \"whitespace\"}],\n",
    "            ]\n",
    "        }\n",
    "    )\n",
    "}\n",
    "model.store(\"en\", metadata)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "fiAxbt37YFYP"
   },
   "source": [
    "# Load the model in NNSplit"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Ynlr1bYuk_ZV"
   },
   "source": [
    "First, install NNSplit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install nnsplit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "2y8mTjA9YRXH"
   },
   "outputs": [],
   "source": [
    "from nnsplit import NNSplit"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "a7tfpxtplE8S"
   },
   "source": [
    "Instantiate the splitter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "3EeDhEjoYj-a"
   },
   "outputs": [],
   "source": [
    "splitter = NNSplit(\"en/model.onnx\", use_cuda=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "zrjzXPnPlPDs"
   },
   "source": [
    "And split a text!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 54
    },
    "colab_type": "code",
    "id": "ERhaXDUlYmcU",
    "outputId": "2a6c9244-6b2b-4e3f-c6c1-27d255ac24cb"
   },
   "outputs": [],
   "source": [
    "splits = splitter.split([\"This is a test This is another test.\"])[0]\n",
    "splits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "WOi2xVTllSjv"
   },
   "source": [
    "The public API of NNSplit has changed significantly and is now much easier to use. Everything is an `nnsplit.Split`, which can be iterated over or converted to a string with `str(...)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 52
    },
    "colab_type": "code",
    "id": "L5DRnwAJY5Px",
    "outputId": "ae5a91ee-8266-4af3-f48b-0924a561486f"
   },
   "outputs": [],
   "source": [
    "for sentence in splits:\n",
    "    print(str(sentence).ljust(30), type(sentence))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "yRt37H-hlmqf"
   },
   "source": [
    "Or if you want to go token-level:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 208
    },
    "colab_type": "code",
    "id": "7mKlF78WY-31",
    "outputId": "f9d3a409-9a98-4088-a32d-ca72ad88e2b2"
   },
   "outputs": [],
   "source": [
    "for sentence in splits:\n",
    "    for token in sentence:\n",
    "        print(str(token).ljust(10), repr(token).ljust(30), type(token))\n",
    "\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "YyVn5FHdluOQ"
   },
   "source": [
    "This continues down to the smallest unit, which is returned as a `str` instead of an `nnsplit.Split`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 486
    },
    "colab_type": "code",
    "id": "xTgFJU1uZgnM",
    "outputId": "b8cea330-810e-4310-b609-0b9ded6317da"
   },
   "outputs": [],
   "source": [
    "for sentence in splits:\n",
    "    for [text, whitespace] in sentence:\n",
    "        print(text.ljust(10), type(text))\n",
    "        print(f'\"{whitespace}\"'.ljust(10), type(whitespace))\n",
    "        print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "wrN4T1lYl3is"
   },
   "source": [
    "Finally, for some benchmarks: If you are running `NNSplit` on GPU, you can increase the speed on large datasets by using a big batch size."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "uYHT8F79bP89"
   },
   "outputs": [],
   "source": [
    "splitter = NNSplit(\"en/model.onnx\", use_cuda=True, batch_size=2**14)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 86
    },
    "colab_type": "code",
    "id": "0BSrcu55bDQV",
    "outputId": "2fa7a5b3-1562-4ac0-dce7-513c0644b828"
   },
   "outputs": [],
   "source": [
    "text = \"This is a test This is another test.\"\n",
    "\n",
    "%timeit splitter.split([text])[0]\n",
    "%timeit splitter.split([text] * 100)[0]\n",
    "%timeit splitter.split([text] * 1000)[0]\n",
    "%timeit splitter.split([text] * 10_000)[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "sTGnTBrqmAFh"
   },
   "source": [
    "And voilà! Splitting 10000 short texts in less than 400 milliseconds."
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyMYlI5r8AMYsV3nRAzHuQpG",
   "collapsed_sections": [],
   "include_colab_link": true,
   "mount_file_id": "1NwZadkDUSSNlbZ2kvMmLc0LRCIB7qT-g",
   "name": "NNSplit.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
